{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a434f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "BRANCH='main'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "developmental-gibraltar",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell\n",
    "\n",
    "# install NeMo\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "challenging-pioneer",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections import nlp as nemo_nlp\n",
    "from nemo.utils.exp_manager import exp_manager\n",
    "\n",
    "import os\n",
    "import wget \n",
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "from omegaconf import OmegaConf"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "employed-ethiopia",
   "metadata": {},
   "source": [
    "In the era of super large language models, the traditional \"pre-train, fine-tune\" procedure is replaced by \"pre-train, prompt, and predict\" method as shown in the [survey paper](https://arxiv.org/pdf/2107.13586.pdf). The prompt method is versatile enough to support all kinds of NLP tasks as shown in the following table: \n",
    "\n",
    "<table>\n",
    "    <thead>\n",
    "        <tr>\n",
    "            <th>Type</th>\n",
    "            <th>Task</th>\n",
    "            <th>Input ([X])</th>\n",
    "            <th>Template</th>\n",
    "            <th>Answer([Y])</th>\n",
    "        </tr>\n",
    "    </thead>\n",
    "    <tbody>\n",
    "        <tr>\n",
    "            <td rowspan=3>Text CLS</td>\n",
    "            <td>Sentiment</td>\n",
    "            <td>I love this movie.</td>\n",
    "            <td>[X] The movie is [Y]</td>\n",
    "            <td>great<br>fantastic<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td>Topics</td>\n",
    "            <td>He prompted the LM.</td>\n",
    "            <td>[X] The text is about [Y]</td>\n",
    "            <td>sports<br>science<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td>Intention</td>\n",
    "            <td>What is taxi fare to Denver?</td>\n",
    "            <td>[X] The question is about [Y]</td>\n",
    "            <td>quantity<br>city<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td rowspan=1>Text-span CLS</td>\n",
    "            <td>Aspect<br>Sentiment</td>\n",
    "            <td>Poor service but good food.</td>\n",
    "            <td>[X] What about service? [Y]</td>\n",
    "            <td>Bad<br>Terrible<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td rowspan=1>Text-pair CLS</td>\n",
    "            <td>NLI</td>\n",
    "            <td>[X1]: An old man with ...<br>[X2]: A man walks ...</td>\n",
    "            <td>Hypothesis: [X1], Premise: [X2], Answer: [Y]</td>\n",
    "            <td>Contradiction<br>Entailment<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td rowspan=1>Tagging</td>\n",
    "            <td>NER</td>\n",
    "            <td>[X1]: Mike went to Paris.<br>[X2]: Paris</td>\n",
    "            <td>[X1] [X2] is a [Y]</td>\n",
    "            <td>Yes<br>No<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td rowspan=2>Text Generation</td>\n",
    "            <td>Summarization</td>\n",
    "            <td>Las Vegas police ...</td>\n",
    "            <td>[X] TL;DR: [Y]</td>\n",
    "            <td>The victim ...<br>A woman ...<br>...</td>\n",
    "        </tr>\n",
    "        <tr>\n",
    "            <td>Translation</td>\n",
    "            <td>Je vous aime.</td>\n",
    "            <td>French [X] English: [Y]</td>\n",
    "            <td>I love you.<br>I fancy you.<br>...</td>\n",
    "        </tr>\n",
    "    </tbody>\n",
    "</table>\n",
    "\n",
    "In this tutorial, we are going to describe how to use [P-Tuning method](https://arxiv.org/pdf/2103.10385.pdf) , which is one of the prompt engineering methods, to find good prompts for large GPT models. We show it can solve multiple downstream NLP tasks with good performance. P-Tuning leverages few continuous free parameters to serve as prompts fed as the input to the pre-trained language models. Freezing the large language model weights, P-Tuning model can be trained efficiently while delivering stats of art performance. \n",
    "\n",
    "Large Language Model can be trained with [NeMo Megatron](https://github.com/NVIDIA/NeMo/tree/main/examples/nlp/language_modeling), up to multi-billion parameters. In this notebook, we will use the pre-trained 344M GPT model released from NGC.\n",
    "\n",
    "# Task Description\n",
    "P-Tuning method can be applied to solve various NLP tasks. Without losing generality, in this notebook, we are going to use P-Tuning method to solve two NLP tasks: **Sentiment Analysis** task and **Question and Answer** task. \n",
    "\n",
    "**Sentiment Analysis** task is also known as opinion mining or emotion AI. It is a sub-field of NLP that tries to identify and extract opinions within a given text across blogs, reviews, social media, forums, news etc. \n",
    "\n",
    "For instance, **given sentences from news title, is it a good or bad news?**<br>\n",
    "\n",
    "**Question and Answer** task is to find the answer to a question given the context text. \n",
    "\n",
    "For instance, \n",
    "```\n",
    "Context: \n",
    "Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles \"Crazy in Love\" and \"Baby Boy\".\n",
    "Question:\n",
    "How many Grammy awards did Beyoncé win for her first solo album?\n",
    "```\n",
    "\n",
    "# Dataset\n",
    "We will use [Financial PhraseBank dataset](https://huggingface.co/datasets/financial_phrasebank) for sentiment analysis task and [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) for question and answer task.\n",
    "\n",
    "The [Financial PhraseBank dataset](https://huggingface.co/datasets/financial_phrasebank) contains the sentiments for financial news headlines from the perspective of a retail investor. Further details about the dataset can be found in: Malo, P., Sinha, A., Takala, P., Korhonen, P. and Wallenius, J. (2014): “Good debt or bad debt: Detecting semantic orientations in economic texts.” Journal of the American Society for Information Science and Technology.\n",
    "\n",
    "Here's an example of what an annotated abstract from the corpus looks like:\n",
    "\n",
    "```\n",
    "HELSINKI Thomson Financial - Shares in Cargotec fell sharply in early afternoon trade after the cargo handling group posted a surprise drop in April-June profits , which overshadowed the large number of new orders received during the three months .@negative\n",
    "LONDON MarketWatch -- Share prices ended lower in London Monday as a rebound in bank stocks failed to offset broader weakness for the FTSE 100 .@negative\n",
    "Operating profit fell to EUR 35.4 mn from EUR 68.8 mn in 2007 , including vessel sales gain of EUR 12.3 mn .@negative\n",
    "Sales in Finland decreased by 10.5 % in January , while sales outside Finland dropped by 17 % .@negative\n",
    "```\n",
    "\n",
    "The [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
    "\n",
    "Let's download the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "federal-beads",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = \"DATA_DIR\"\n",
    "os.makedirs(DATA_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c1e1b08",
   "metadata": {},
   "source": [
    "## Downloading Financial Phrase Bank Dataset\n",
    "\n",
    "The dataset is collected by Malo et al. 2014, and can be downloaded from this [link](https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip). The zip file for the Financial Phrase Bank Dataset has been provided for ease of download and use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad03fc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v10/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v10.zip\n",
    "!unzip FinancialPhraseBank-v10.zip -d {DATA_DIR}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "radical-castle",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you want to see more examples, you can explore the text of the corpus using the file browser to the left, or open files directly, for example typing a command like the following in a code-cell:\n",
    "\n",
    "! head -1 $DATA_DIR/FinancialPhraseBank-v1.0/Sentences_50Agree.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b0f69a9",
   "metadata": {},
   "source": [
    "## Download the SQuAD dataset\n",
    "\n",
    "Download a copy of the dataset (distributed under the CC BY-SA 4.0 license):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e91c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json\n",
    "!mv train-v2.0.json {DATA_DIR}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "affected-numbers",
   "metadata": {},
   "source": [
    "## Pre-process Financial Phrase Bank Dataset\n",
    "\n",
    "In this pre-process step, we are going to convert the downloaded dataset into the format that can be used for P-Tuning dataloader. The data is split into 10 folds so we can do 10-fold cross validation. In this notebook, we will use the first fold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "198287d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "random.seed(1234)\n",
    "files = ['Sentences_50Agree.txt', 'Sentences_66Agree.txt', 'Sentences_75Agree.txt', 'Sentences_AllAgree.txt']\n",
    "base_dir = DATA_DIR + '/FinancialPhraseBank-v1.0/'\n",
    "files = [base_dir + f for f in files]\n",
    "\n",
    "alllines = []\n",
    "for fn in files:\n",
    "    with open(fn, 'r', encoding=\"ISO-8859-1\") as f:\n",
    "        alllines.extend(f.readlines())\n",
    "\n",
    "random.shuffle(alllines)\n",
    "fold = 10\n",
    "fold_size = len(alllines) // fold\n",
    "\n",
    "chunk_start = list(range(0, 14780, 1478))\n",
    "\n",
    "chunks = []\n",
    "\n",
    "for start_id in chunk_start:\n",
    "    chunks.append(alllines[start_id:start_id+fold_size])\n",
    "\n",
    "def gen_file(data, fold_id, split_type):\n",
    "    filename = \"{}/{}_{}.txt\".format(base_dir, split_type, fold_id)\n",
    "    with open(filename, 'w') as f:\n",
    "        obj = {}\n",
    "        for line in data:\n",
    "            splits = line.split('@')\n",
    "            part1 = splits[0].strip()\n",
    "            part2 = splits[1].strip()\n",
    "            obj['sentence'] = part1\n",
    "            obj['label'] = part2\n",
    "            obj['prompt_tag'] = 'sentiment-task'\n",
    "            f.write(json.dumps(obj)+'\\n')\n",
    "\n",
    "\n",
    "def gen_fold(fold_number):\n",
    "    lists = list(range(fold))\n",
    "    test_id = (fold_number + fold) % fold\n",
    "    val_id = (fold_number + fold - 1) % fold\n",
    "    test_set = chunks[test_id]\n",
    "    val_set = chunks[val_id]\n",
    "    lists.remove(test_id)\n",
    "    lists.remove(val_id)\n",
    "    train_set = []\n",
    "    for idd in lists:\n",
    "        train_set += chunks[idd]\n",
    "    gen_file(train_set, fold_number, 'train')\n",
    "    gen_file(val_set, fold_number, 'validation')\n",
    "    gen_file(test_set, fold_number, 'test')\n",
    "\n",
    "gen_fold(0)"
   ]
  },
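  {
   "cell_type": "markdown",
   "id": "ptune-sketch-fold-rotation",
   "metadata": {},
   "source": [
    "The fold rotation in `gen_fold` can be looked at in isolation. The snippet below is a minimal sketch rather than part of the pipeline: `split_ids` is a hypothetical helper that only reports which chunk indices would serve as training, validation and test data for a given fold.\n",
    "\n",
    "```python\n",
    "def split_ids(fold_number, fold=10):\n",
    "    # Same arithmetic as gen_fold: chunk `fold_number` is the test set and\n",
    "    # the chunk before it (wrapping around) is the validation set.\n",
    "    test_id = (fold_number + fold) % fold\n",
    "    val_id = (fold_number + fold - 1) % fold\n",
    "    train_ids = [i for i in range(fold) if i not in (test_id, val_id)]\n",
    "    return train_ids, val_id, test_id\n",
    "\n",
    "\n",
    "for k in (0, 1):\n",
    "    train_ids, val_id, test_id = split_ids(k)\n",
    "    print(k, train_ids, val_id, test_id)\n",
    "```\n",
    "\n",
    "With 10 folds, fold 0 uses chunk 0 for testing, chunk 9 for validation and chunks 1 to 8 for training."
   ]
  },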
  {
   "cell_type": "markdown",
   "id": "graphic-debate",
   "metadata": {},
   "source": [
    "The data is converted to the loss json file. Each line has three keys \"sentence\", \"label\" and \"prompt_tag\". \n",
    "Here are the first two lines of converted data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sound-surgeon",
   "metadata": {},
   "outputs": [],
   "source": [
    "!head -n 2 $DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt"
   ]
  },
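  {
   "cell_type": "markdown",
   "id": "ptune-sketch-sentiment-record",
   "metadata": {},
   "source": [
    "As a rough illustration of the conversion, the sketch below turns one raw line into a record with the same three keys. `to_record` is a hypothetical helper, and splitting on the last `@` is an assumption made here so that a sentence containing `@` would keep its label intact.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def to_record(line):\n",
    "    # Everything before the last '@' is the sentence,\n",
    "    # everything after it is the sentiment label.\n",
    "    sentence, label = line.rsplit('@', 1)\n",
    "    return {\n",
    "        'sentence': sentence.strip(),\n",
    "        'label': label.strip(),\n",
    "        'prompt_tag': 'sentiment-task',\n",
    "    }\n",
    "\n",
    "\n",
    "raw = 'Sales in Finland decreased by 10.5 % in January , while sales outside Finland dropped by 17 % .@negative\\n'\n",
    "print(json.dumps(to_record(raw)))\n",
    "```\n",
    "\n",
    "The `prompt_tag` value is what later selects the template for this record."
   ]
  },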
  {
   "cell_type": "markdown",
   "id": "71556cbe",
   "metadata": {},
   "source": [
    "### Preprocess SQuAD Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "spectacular-strain",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name = DATA_DIR + '/train-v2.0.json'\n",
    "with open(file_name, 'r') as f:\n",
    "    data_obj = json.load(f)\n",
    "\n",
    "articles = data_obj['data']\n",
    "test_len = 40\n",
    "validation_len = 40\n",
    "train_len = len(articles) - test_len - validation_len\n",
    "train_records = []\n",
    "validation_records = []\n",
    "test_records = []\n",
    "\n",
    "\n",
    "def get_records(sub_articals, records):\n",
    "    for article in sub_articals:\n",
    "        paragraphs = article['paragraphs']\n",
    "        for paragraph in paragraphs:\n",
    "            qas = paragraph['qas']\n",
    "            context = paragraph['context'].strip()\n",
    "            for qa in qas:\n",
    "                record = {}\n",
    "                record['question'] = qa['question'].strip()\n",
    "                record['context'] = context\n",
    "                if qa['is_impossible']:\n",
    "                    record['label'] = 'NA'\n",
    "                else:\n",
    "                    record['label'] = qa['answers'][0]['text'].strip()\n",
    "                record['prompt_tag'] = 'qa-task'\n",
    "                records.append(json.dumps(record))\n",
    "get_records(articles[:train_len], train_records)\n",
    "get_records(articles[train_len:train_len+validation_len], validation_records)\n",
    "get_records(articles[train_len+validation_len:], test_records)\n",
    "random.shuffle(train_records)\n",
    "random.shuffle(validation_records)\n",
    "random.shuffle(test_records)\n",
    "squad_dir = \"DATA_DIR/squad\"\n",
    "os.makedirs(squad_dir, exist_ok=True)\n",
    "with open(squad_dir+'/train.txt', 'w') as f:\n",
    "    f.write(\"\\n\".join(train_records))\n",
    "with open(squad_dir+'/validation.txt', 'w') as f:\n",
    "    f.write(\"\\n\".join(validation_records))\n",
    "with open(squad_dir+'/test.txt', 'w') as f:\n",
    "    f.write(\"\\n\".join(test_records))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d9e4788",
   "metadata": {},
   "source": [
    "The data is converted to the loss json file. Each line has three keys \"question\", \"context\", \"label\" and \"prompt_tag\". \n",
    "Here are the first two lines of converted data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bcd460d",
   "metadata": {},
   "outputs": [],
   "source": [
    "!head -n 2 {squad_dir}/train.txt"
   ]
  },
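  {
   "cell_type": "markdown",
   "id": "ptune-sketch-squad-flatten",
   "metadata": {},
   "source": [
    "To make the flattening step easier to follow, here is a small self-contained sketch that applies the same logic to a made-up article. The article content is invented for illustration and is not taken from SQuAD, and `flatten` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# A tiny, invented article in the SQuAD 2.0 layout (not real dataset content).\n",
    "article = {\n",
    "    'paragraphs': [\n",
    "        {\n",
    "            'context': 'NeMo is a toolkit for conversational AI.',\n",
    "            'qas': [\n",
    "                {'question': 'What is NeMo?', 'is_impossible': False,\n",
    "                 'answers': [{'text': 'a toolkit for conversational AI'}]},\n",
    "                {'question': 'Who founded NeMo?', 'is_impossible': True,\n",
    "                 'answers': []},\n",
    "            ],\n",
    "        }\n",
    "    ]\n",
    "}\n",
    "\n",
    "\n",
    "def flatten(article):\n",
    "    # One record per question; unanswerable questions get the label 'NA'.\n",
    "    for paragraph in article['paragraphs']:\n",
    "        context = paragraph['context'].strip()\n",
    "        for qa in paragraph['qas']:\n",
    "            label = 'NA' if qa['is_impossible'] else qa['answers'][0]['text'].strip()\n",
    "            yield {'question': qa['question'].strip(), 'context': context,\n",
    "                   'label': label, 'prompt_tag': 'qa-task'}\n",
    "\n",
    "\n",
    "for record in flatten(article):\n",
    "    print(json.dumps(record))\n",
    "```\n",
    "\n",
    "Only the first listed answer is kept as the label, which mirrors the pre-processing cell above."
   ]
  },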
  {
   "cell_type": "markdown",
   "id": "8ba98b18",
   "metadata": {},
   "source": [
    "### Combine the two datasets\n",
    "\n",
    "The P-tune model includes a prompt encoder which is used to generate virtual tokens. Its output can be conditioned on the task tags so the P-tune model supports multiple tasks simultaneously. We are going to mix the Financial phrase bank dataset and SQuAD dataset together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a30f7ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "mix_data_dir = f\"{DATA_DIR}/mix\"\n",
    "os.makedirs(mix_data_dir, exist_ok=True)\n",
    "!cat $DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt {squad_dir}/train.txt | shuf > {mix_data_dir}/train.txt\n",
    "!cat $DATA_DIR/FinancialPhraseBank-v1.0/validation_0.txt {squad_dir}/validation.txt | shuf > {mix_data_dir}/validation.txt\n",
    "!cat $DATA_DIR/FinancialPhraseBank-v1.0/test_0.txt {squad_dir}/test.txt | shuf > {mix_data_dir}/test.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42a7d704",
   "metadata": {},
   "source": [
    "Here are the first two lines of converted data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e373c870",
   "metadata": {},
   "outputs": [],
   "source": [
    "!head -n 2 {mix_data_dir}/train.txt"
   ]
  },
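  {
   "cell_type": "markdown",
   "id": "ptune-sketch-mix-files",
   "metadata": {},
   "source": [
    "The `cat ... | shuf` commands rely on shell tools that may not be available everywhere. A pure-Python equivalent could look like the sketch below. `mix_files` is a hypothetical helper, and the fixed seed is an assumption added for reproducibility (`shuf` as used above is unseeded).\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def mix_files(paths, out_path, seed=1234):\n",
    "    # Collect every non-empty line from each input file.\n",
    "    lines = []\n",
    "    for path in paths:\n",
    "        with open(path, 'r') as f:\n",
    "            lines.extend(line.rstrip('\\n') for line in f if line.strip())\n",
    "    # Shuffle with a dedicated generator so the global random state is untouched.\n",
    "    random.Random(seed).shuffle(lines)\n",
    "    with open(out_path, 'w') as f:\n",
    "        f.write('\\n'.join(lines) + '\\n')\n",
    "    return len(lines)\n",
    "\n",
    "\n",
    "# Example call, assuming the files created earlier in this notebook exist:\n",
    "# mix_files(['DATA_DIR/FinancialPhraseBank-v1.0/train_0.txt', 'DATA_DIR/squad/train.txt'],\n",
    "#           'DATA_DIR/mix/train.txt')\n",
    "```\n",
    "\n",
    "The function returns the number of records written, which is a quick way to check that both sources contributed lines."
   ]
  },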
  {
   "cell_type": "markdown",
   "id": "3813cc36",
   "metadata": {},
   "source": [
    "## Convert the Megatron-LM Weights to Nemo file\n",
    "\n",
    "P-Tuning method works the best with large GPT language models. From our experiences, models of size 5B or above give good performance. If you already have a large GPT model ready, skip this section. \n",
    "\n",
    "In this example, we will use the pretrained 344M NeMo Megatron GPT model from [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM). To load it in NeMo Megatron, We first need to convert the Megatron-LM checkpoint to the `.nemo` file. Let's download the pretrained model weights and vocabulary file.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b8e08e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "gpt_file = 'megatron_lm_345m_v0.0.zip'\n",
    "vocab_file = 'gpt2-vocab.json'\n",
    "merge_file = 'gpt2-merge.txt'\n",
    "checkpoint_filename = 'model_optim_rng.pt'\n",
    "\n",
    "if not pathlib.Path(gpt_file).exists():\n",
    "    !wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O $gpt_file\n",
    "    !unzip -o $gpt_file\n",
    "    !wget https://s3.amazonaws.com/models.huggingface.co/bert/$vocab_file -O $vocab_file \n",
    "    !wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt -O $merge_file\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b00ee86",
   "metadata": {},
   "outputs": [],
   "source": [
    "WORK_DIR = \"WORK_DIR\"\n",
    "os.makedirs(WORK_DIR, exist_ok=True)\n",
    "\n",
    "# Prepare the model parameters \n",
    "# download the model's configuration file \n",
    "config_dir = WORK_DIR + '/configs/'\n",
    "MODEL_CONFIG = \"megatron_gpt_config.yaml\"\n",
    "os.makedirs(config_dir, exist_ok=True)\n",
    "if not os.path.exists(config_dir + MODEL_CONFIG):\n",
    "    print('Downloading config file...')\n",
    "    wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/' + MODEL_CONFIG, config_dir)\n",
    "else:\n",
    "    print ('config file is already exists')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ae5a1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this line will print the entire config of the model\n",
    "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n",
    "print(config_path)\n",
    "config = OmegaConf.load(config_path)\n",
    "config.model.num_layers = 24\n",
    "config.model.hidden_size = 1024\n",
    "config.model.ffn_hidden_size = 4096\n",
    "config.model.num_attention_heads = 16\n",
    "config.model.tokenizer.vocab_file = vocab_file\n",
    "config.model.tokenizer.merge_file = merge_file\n",
    "config.model.tensor_model_parallel_size = 1\n",
    "config.model.data.data_prefix = ''\n",
    "config.model.max_position_embeddings = 1024\n",
    "config.model.data.seq_length = 1024\n",
    "config.model.encoder_seq_length = 1024\n",
    "config.cfg = {}\n",
    "config.cfg.cfg = config.model\n",
    "with open('hparams.yaml', 'w') as f:\n",
    "    f.write(OmegaConf.to_yaml(config.cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e1beda4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "PWD = os.getcwd()\n",
    "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py')\n",
    "!python -m torch.distributed.run --nproc_per_node=1 megatron_lm_ckpt_to_nemo.py --checkpoint_folder=$PWD/release/mp_rank_00/ --checkpoint_name=$checkpoint_filename --hparams_file=$PWD/hparams.yaml --nemo_file_path=$PWD/gpt_344m.nemo --model_type=gpt --tensor_model_parallel_size=1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84b455a6",
   "metadata": {},
   "source": [
    "# Model configuration\n",
    "\n",
    "Our P-Tuning text classification model is comprised of the pretrained GPT LM model followed by a prompt encoder layer.\n",
    "\n",
    "The model is defined in a config file which declares multiple important sections. They are:\n",
    "- **model**: All arguments that are related to the Model - language model, token classifier, optimizer and schedulers, datasets and any other related information\n",
    "\n",
    "- **trainer**: Any argument to be passed to PyTorch Lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "speaking-grant",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_CONFIG = \"megatron_ptune_gpt.yaml\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "demanding-ballet",
   "metadata": {},
   "outputs": [],
   "source": [
    "# download the model's configuration file \n",
    "config_dir = WORK_DIR + '/configs/'\n",
    "os.makedirs(config_dir, exist_ok=True)\n",
    "if not os.path.exists(config_dir + MODEL_CONFIG):\n",
    "    print('Downloading config file...')\n",
    "    wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/' + MODEL_CONFIG, config_dir)\n",
    "else:\n",
    "    print ('config file is already exists')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "criminal-outdoors",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this line will print the entire config of the model\n",
    "config_path = f'{WORK_DIR}/configs/{MODEL_CONFIG}'\n",
    "print(config_path)\n",
    "config = OmegaConf.load(config_path)\n",
    "# Note: these are small batch-sizes - increase as appropriate to available GPU capacity\n",
    "config.model.data.train_ds.batch_size=8\n",
    "config.model.data.validation_ds.batch_size=8"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dedicated-effort",
   "metadata": {},
   "source": [
    "# Model Training\n",
    "## Setting up Data within the config\n",
    "\n",
    "Among other things, the config file contains dictionaries called train_ds, validation_ds and test_ds. These are configurations used to setup the Dataset and DataLoaders of the corresponding config.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "informed-purse",
   "metadata": {},
   "outputs": [],
   "source": [
    "# in this tutorial train and dev datasets are located in the same folder, so it is enough to add the path of the data directory to the config\n",
    "config.model.data.train_ds.file_path = DATA_DIR+'/mix/train.txt'\n",
    "config.model.data.validation_ds.file_path = DATA_DIR+'/mix/validation.txt'\n",
    "config.model.data.test_ds.file_path = DATA_DIR+'/mix/test.txt'\n",
    "\n",
    "\n",
    "# if you want to decrease the size of your datasets, uncomment the lines below:\n",
    "# NUM_SAMPLES = 1000\n",
    "# config.model.data.train_ds.num_samples = NUM_SAMPLES\n",
    "# config.model.data.validation_ds.num_samples = NUM_SAMPLES"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3304d29b",
   "metadata": {},
   "source": [
    "## Add the Data Processors to Generate the Prompts\n",
    "\n",
    "To customize different prompts for different tasks, we can configure the `TemplateProcessor` to define the template for the prompt input. The curly brackets defines the variables in the template string. `{v0}`, `{v1}`, `{v2}` indicates the virtual token of length `prompt_encoder.template[0]`, `prompt_encoder.template[1]` and `prompt_encoder.template[2]`. The other variables `{var}` refers to the variables in the data record. For example.\n",
    "\n",
    "Given the data record, **{\"sentence1\": \"And he said, Mama, I'm home.\", \"sentence2\": \"He didn't say a word.\"}** and template list [3, 3, 3], \n",
    "the template string **{v0} Hypothesis: [sentence1], {v1} Premise: [sentence2] {v2} Answer:** will be translated into **<span style=\"color:red\">VVV</span> Hypothesis: And he said, Mama, I'm home.<span style=\"color:red\">VVV</span> Premise: He didn't say a word.<span style=\"color:red\">VVV</span> Answer:**, where <span style=\"color:red\">VVV</span> is virtual token of space 3.\n",
    "\n",
    "Let's configure the proper template for the two dataset we prepared:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5ff646d",
   "metadata": {},
   "outputs": [],
   "source": [
    "  config.model.task_processors = [\n",
    "    {\n",
    "      \"taskname\": \"qa-task\",\n",
    "      \"template\": \"{v0} Context: {context}{v1} Question: {question}?{v2} Answer:\",\n",
    "      \"limit_length_field\": \"content\",\n",
    "    },\n",
    "    {\n",
    "      \"taskname\": \"sentiment-task\",\n",
    "      \"template\": \"{v0}{v1} Sentence: {sentence}{v2} Sentiment:\",\n",
    "      \"limit_length_field\": \"sentence\",\n",
    "    },\n",
    "  ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a9f0d09",
   "metadata": {},
   "source": [
    "Note each `task_processors` items has 3 fields. Besides the `template` string, the `taskname` refers to the `prompt_tag` in the data record. The `limit_length_field` specifies the which field in the data is going to be cut if the length of the input exceeds the maximum sequence length of the model.\n",
    "\n",
    "Register the data processors with `register_taskdata_processor` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48aa4d7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.nlp.data.glue_benchmark.gpt_ptune_dataset import register_taskdata_processor,  TemplateProcessor\n",
    "\n",
    "\n",
    "for processor_config in config.model.task_processors:\n",
    "    processor = TemplateProcessor(\n",
    "        template=processor_config.template, limit_length_field=processor_config.limit_length_field\n",
    "    )\n",
    "    register_taskdata_processor(processor_config.taskname, processor)"
   ]
  },
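  {
   "cell_type": "markdown",
   "id": "ptune-sketch-template-render",
   "metadata": {},
   "source": [
    "To build some intuition for how a template string is filled in, the following toy sketch expands a template for one record. It is not NeMo's `TemplateProcessor`: `render` is a hypothetical function, and it writes placeholder characters where the real model inserts virtual tokens generated by the prompt encoder.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "\n",
    "def render(template, record, virtual_lengths, placeholder='V'):\n",
    "    # Replace {vN} with virtual_lengths[N] placeholder characters and\n",
    "    # every other {name} with the matching field of the record.\n",
    "    def substitute(match):\n",
    "        name = match.group(1)\n",
    "        virtual = re.fullmatch(r'v(\\d+)', name)\n",
    "        if virtual:\n",
    "            return placeholder * virtual_lengths[int(virtual.group(1))]\n",
    "        return record[name]\n",
    "\n",
    "    return re.sub(r'\\{(\\w+)\\}', substitute, template)\n",
    "\n",
    "\n",
    "record = {'sentence': 'Sales in Finland decreased by 10.5 % in January .'}\n",
    "template = '{v0}{v1} Sentence: {sentence}{v2} Sentiment:'\n",
    "print(render(template, record, [3, 3, 3]))\n",
    "```\n",
    "\n",
    "Because `{v0}` and `{v1}` are adjacent in the sentiment template, their placeholders appear back to back at the start of the rendered prompt."
   ]
  },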
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "divine-belly",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "changed-mauritius",
   "metadata": {},
   "source": [
    "## Building the PyTorch Lightning Trainer\n",
    "\n",
    "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem.\n",
    "\n",
    "Let's first instantiate a Trainer object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "computational-battlefield",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Trainer config - \\n\")\n",
    "print(OmegaConf.to_yaml(config.trainer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unique-genre",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin\n",
    "\n",
    "\n",
    "# lets modify some trainer configs\n",
    "# checks if we have GPU available and uses it\n",
    "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
    "config.trainer.accelerator = accelerator\n",
    "config.trainer.devices = 1\n",
    "config.trainer.max_epochs = 100\n",
    "config.trainer.val_check_interval=1.0\n",
    "# for PyTorch Native AMP set precision=16\n",
    "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n",
    "\n",
    "# remove distributed training flags\n",
    "config.trainer.strategy = None\n",
    "\n",
    "trainer = pl.Trainer(plugins=[NLPDDPPlugin()], **config.trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "overall-literature",
   "metadata": {},
   "source": [
    "## Setting up a NeMo Experiment\n",
    "\n",
    "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mathematical-portable",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n",
    "os.makedirs(WORK_DIR, exist_ok=True)\n",
    "\n",
    "# the exp_dir provides a path to the current experiment for easy access\n",
    "exp_dir = str(exp_dir)\n",
    "exp_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f62ea6cd",
   "metadata": {},
   "source": [
    "We will use the converted `.nemo` file as our LM model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "compact-horse",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add the specified above model parameters to the config\n",
    "# config.model.language_model.pretrained_model_name = PRETRAINED_BERT_MODEL\n",
    "config.model.language_model.nemo_file = 'gpt_344m.nemo'\n",
    "config.model.tensor_model_parallel_size = 1\n",
    "config.exp_manager.checkpoint_callback_params.save_top_k = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "seeing-geometry",
   "metadata": {},
   "source": [
    "Now, we are ready to initialize our model. During the model initialization call, the dataset and data loaders we'll be prepared for training and evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "indoor-france",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.nlp.models.language_modeling.megatron_ptune_gpt_model import MegatronGPTPTuneModel\n",
    "model_ptune = MegatronGPTPTuneModel(cfg=config.model, trainer=trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "genuine-pipeline",
   "metadata": {},
   "source": [
    "## Monitoring training progress\n",
    "Optionally, you can create a Tensorboard visualization to monitor training progress.\n",
    "If you're not using Colab, refer to [https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks) if you're facing issues with running the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "changed-expense",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from google import colab\n",
    "    COLAB_ENV = True\n",
    "except (ImportError, ModuleNotFoundError):\n",
    "    COLAB_ENV = False\n",
    "\n",
    "# Load the TensorBoard notebook extension\n",
    "if COLAB_ENV:\n",
    "    %load_ext tensorboard\n",
    "    %tensorboard --logdir {exp_dir}\n",
    "else:\n",
    "    print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "applied-quality",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start model training\n",
    "trainer.fit(model_ptune)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cooperative-michael",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "To see how the model performs, we can run model in the inference mode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "classical-scientist",
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's first create a subset of our dev data\n",
    "query_examples = [\n",
    "\n",
    "]\n",
    "results = model_ptune.cuda().ptune_inference(queries=query_examples, batch_size=1, decode_token_len=15)\n",
    "print('The prediction results of some sample queries with the trained model:')\n",
    "for query, result in zip(query_examples, results):\n",
    "    print(f'Query : {query}')\n",
    "    print(f'Predicted label: {result}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "connected-typing",
   "metadata": {},
   "source": [
    "## Training Script\n",
    "\n",
    "If you have NeMo installed locally, you can also train the model with `examples/nlp/text_classification/ptune_text_classification.py`.\n",
    "\n",
    "To run training script, use:\n",
    "```\n",
    "python examples/nlp/text_classification/ptune_text_classification.py \\\n",
    "    trainer.devices=1 \\\n",
    "    model.tokenizer.vocab_file=VOCAB_FILE \\\n",
    "    model.tensor_model_parallel_size=1 \\\n",
    "    model.language_model.nemo_file=gpt_344m.nemo \\\n",
    "    model.train_ds.file_path=TRAIN_FILE \\\n",
    "    model.prompt_encoder.template=[3,3,3] \\\n",
    "    model.train_ds.batch_size=8 \\\n",
    "    model.validation_ds.file_path=VAL_FILE \\\n",
    "    model.test_ds.file_path=TEST_FILE \\\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ddb3960",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
